package archive

import (
	"archive/zip"
	"context"
	"gitee.com/Luna-CY/hui-hui/internal/util/goroutine"
	"github.com/cheggaaa/pb/v3"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
)

// CompressZipFile writes the file or directory at from into a zip archive at to,
// creating or truncating to.
//
// A regular file becomes a single deflated entry named after its base name. A
// directory is walked recursively with filepath.Walk and stored under its own
// base name: compressing "a/b" yields entries "b/", "b/x.txt", and so on.
// Directory entries carry only a header (name ending in "/"); file entries are
// deflated. Entry names always use forward slashes, as the zip format requires.
//
// When from is a directory, ctx is checked before each path is visited and
// ctx.Err() is returned on cancellation. Walk errors, read errors and failures
// to finalize the archive (writing the central directory or closing to) are
// all returned. A partially written file at to is left in place on error.
func CompressZipFile(ctx context.Context, from string, to string) (err error) {
	// Clean so that a trailing separator ("dir/") does not shift entry names.
	from = filepath.Clean(from)

	stat, err := os.Stat(from)
	if nil != err {
		return err
	}

	target, err := os.OpenFile(to, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if nil != err {
		return err
	}
	defer func() {
		if cerr := target.Close(); nil == err {
			err = cerr
		}
	}()

	var writer = zip.NewWriter(target)
	defer func() {
		// Close flushes pending data and writes the central directory. If it
		// fails the archive is unusable, so the error must reach the caller.
		// (Deferred after target's Close, so it runs first.)
		if cerr := writer.Close(); nil == err {
			err = cerr
		}
	}()

	if !stat.IsDir() {
		return addFileToZip(writer, from, filepath.Base(from))
	}

	// Entry names are paths relative to the parent of from, so the top-level
	// entry is the directory itself. The prefix always ends in a separator so
	// that trimming it never leaves a leading slash.
	var prefix = filepath.Dir(from)
	if !strings.HasSuffix(prefix, string(filepath.Separator)) {
		prefix += string(filepath.Separator)
	}

	return filepath.Walk(from, func(path string, info fs.FileInfo, walkErr error) error {
		// info may be nil when walkErr is set; abort rather than silently
		// producing an incomplete archive.
		if nil != walkErr {
			return walkErr
		}
		if err := ctx.Err(); nil != err {
			return err
		}

		var name = filepath.ToSlash(strings.TrimPrefix(path, prefix))

		if !info.IsDir() {
			return addFileToZip(writer, path, name)
		}

		// Directories only need a header; the trailing slash marks them as such.
		header, err := zip.FileInfoHeader(info)
		if nil != err {
			return err
		}
		header.Name = name + "/"

		_, err = writer.CreateHeader(header)
		return err
	})
}

// addFileToZip copies the file at path into w as a deflated entry called name.
// The header (mode, size, modification time) is taken from the opened file, so
// a symbolic link is stored as the regular file it resolves to, matching the
// bytes actually written.
func addFileToZip(w *zip.Writer, path string, name string) error {
	src, err := os.Open(path)
	if nil != err {
		return err
	}
	defer func() {
		// Read-only handle: a close error cannot lose data.
		_ = src.Close()
	}()

	stat, err := src.Stat()
	if nil != err {
		return err
	}

	header, err := zip.FileInfoHeader(stat)
	if nil != err {
		return err
	}
	header.Name = name
	header.Method = zip.Deflate

	dst, err := w.CreateHeader(header)
	if nil != err {
		return err
	}

	_, err = io.Copy(dst, src)
	return err
}

// DecompressZipFile extracts the zip archive held in file (size bytes long)
// into the directory target.
//
// All archive reads go through a ProxyReaderAt built from file, a pb progress
// bar started with total size, and a buffered channel. Every value received on
// that channel is added to a running total, which is reported as
// callback(size, total) when callback is non-nil.
// NOTE(review): the values are presumably the byte counts of each ReadAt, so
// the total tracks bytes read from the archive (including zip metadata), not
// bytes written. Confirm against ProxyReaderAt.
//
// ctx is checked before each entry and ctx.Err() is returned on cancellation.
// Entries whose names would resolve outside target (e.g. "../x") are rejected
// with an error instead of being written. Directories are created with mode
// 0755; files are created or truncated with the mode recorded in the archive.
func DecompressZipFile(ctx context.Context, file *os.File, size int64, target string, callback func(total int64, current int64)) error {
	var bar = pb.Full.Start64(size)
	var proxy = NewProxyReaderAt(file, bar, make(chan int64, 100))

	defer func() {
		bar.Finish()
		_ = proxy.Close()
	}()

	// Start draining proxy.ch before the first read through proxy
	// (zip.NewReader reads the central directory). This way the channel's
	// buffer is consumed from the start rather than only after NewReader
	// returns.
	// NOTE(review): this goroutine is not waited for, so callback may still
	// fire briefly after this function returns. Its exit relies on proxy.ch
	// being closed, presumably by proxy.Close; confirm.
	goroutine.Go(func() {
		var current int64
		for n := range proxy.ch {
			current += n
			if nil != callback {
				callback(size, current)
			}
		}
	})

	var reader, err = zip.NewReader(proxy, size)
	if nil != err {
		return err
	}

	for _, hdr := range reader.File {
		if err := ctx.Err(); nil != err {
			return err
		}

		dest, err := resolveZipEntryPath(target, hdr.Name)
		if nil != err {
			return err
		}

		if hdr.FileInfo().IsDir() {
			if err := os.MkdirAll(dest, 0755); nil != err {
				return err
			}
			continue
		}

		// Archives need not contain explicit entries for parent directories.
		if err := os.MkdirAll(filepath.Dir(dest), 0755); nil != err {
			return err
		}

		if err := extractZipEntry(hdr, dest); nil != err {
			return err
		}
	}

	return nil
}

// resolveZipEntryPath joins the archive entry name onto target and rejects
// names that would land outside target (such as "../x"). This guards against
// "zip slip" path traversal from crafted archives.
func resolveZipEntryPath(target string, name string) (string, error) {
	var dest = filepath.Join(target, name)

	rel, err := filepath.Rel(target, dest)
	if nil != err || ".." == rel || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", &fs.PathError{Op: "unzip", Path: name, Err: fs.ErrInvalid}
	}

	return dest, nil
}

// extractZipEntry decompresses the file entry f into dest, creating or
// truncating dest with the entry's recorded mode. Both handles are always
// closed. A failure to close dest is returned when nothing else failed, since
// it can carry a deferred write error.
func extractZipEntry(f *zip.File, dest string) (err error) {
	// Open the entry first so an unsupported or corrupt entry does not leave an
	// empty file behind.
	src, err := f.Open()
	if nil != err {
		return err
	}
	defer func() {
		// Read-only stream: a close error cannot lose data.
		_ = src.Close()
	}()

	dst, err := os.OpenFile(dest, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, f.FileInfo().Mode())
	if nil != err {
		return err
	}
	defer func() {
		if cerr := dst.Close(); nil == err {
			err = cerr
		}
	}()

	_, err = io.Copy(dst, src)
	return err
}
